One example is that the symmetry of your output has to be equal to or higher than the symmetry of your input. The following simple tasks help demonstrate this:
We will see that we can quickly do Task 1, but not Task 2. Only by using symmetry breaking in Task 3 and Task 4 are we able to distort a square into a rectangle.
import torch
from functools import partial
import numpy as np
import e3nn
import e3nn.o3 as o3
from e3nn.point.operations import Convolution
from e3nn.non_linearities import GatedBlock
from e3nn.kernel import Kernel
from e3nn.kernel_mod import Kernel as KernelMod
from e3nn.radial import CosineBasisModel
from e3nn.non_linearities import rescaled_act
import matplotlib.pyplot as plt
%matplotlib inline
from e3nn.spherical_tensor import SphericalTensor
torch.set_default_dtype(torch.float64)
# Define our geometry: the four corners of a unit square centered on the
# origin, and a rectangle obtained by scaling that square by (sx, sy).
square = torch.tensor(
    [[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]]
)
square = square - square.mean(-2)

sx, sy = 0.5, 1.5
rectangle = square * torch.tensor([sx, sy, 0.])
rectangle = rectangle - rectangle.mean(-2)

N, _ = square.shape  # number of points
markersize = 15
def plot_task(ax, start, finish, title, marker=None):
    """Draw one displacement task on a matplotlib axes.

    Plots the ``start`` polygon, the ``finish`` polygon, and an arrow
    from each start point to the corresponding finish point.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    start, finish : (P, 3) tensors of point coordinates (only x, y drawn).
    title : str, axes title.
    marker : optional matplotlib marker style for the start points.
    """
    # Derive the point count from the data instead of the module-level N,
    # so the function also works for geometries of other sizes.
    n_points = start.shape[0]
    # Append the FIRST point to close the polygon outline. (The original
    # concatenated the whole point set with itself, which also closed the
    # loop but redundantly overdrew every edge.)
    ax.plot(torch.cat([start[:, 0], start[:1, 0]]),
            torch.cat([start[:, 1], start[:1, 1]]), 'o-',
            markersize=markersize + 5 if marker else markersize,
            marker=marker if marker else 'o')
    ax.plot(torch.cat([finish[:, 0], finish[:1, 0]]),
            torch.cat([finish[:, 1], finish[:1, 1]]), 'o-', markersize=markersize)
    for i in range(n_points):
        ax.arrow(start[i, 0], start[i, 1],
                 finish[i, 0] - start[i, 0],
                 finish[i, 1] - start[i, 1],
                 length_includes_head=True, head_width=0.05, facecolor="black", zorder=100)
    ax.set_title(title)
    ax.set_axis_off()
# Show the two basic tasks side by side (the symmetry-breaking variants
# are visualized later in the tutorial).
fig, axes = plt.subplots(1, 2, figsize=(9, 6))
tasks = [
    (rectangle, square, "Task 1: Rectangle to Square"),
    (square, rectangle, "Task 2: Square to Rectangle"),
]
for ax, (start, finish, title) in zip(axes, tasks):
    plot_task(ax, start, finish, title)
In these tasks, we want to move 4 points in one configuration to another configuration. The input to the network will be the initial geometry and features on that geometry. The output will be used to signify "displacement" of each point to the new configuration. We can represent displacement in a couple of different ways. The simplest way is to represent a displacement as an L=1 vector, Rs=[(1, 1)]. However, to better illustrate the symmetry properties of the network, we are instead going to use a spherical harmonic signal, or more specifically the peak of the spherical harmonic signal, to signify the displacement of the original point.
First, we set up a very basic network that has the same representation list Rs = [(1, L) for L in range(5 + 1)] throughout the entire network. The input will be a spherical tensor with representation Rs and the output will also be a spherical tensor with representation Rs. We will interpret the output of the network as a spherical harmonic signal where the peak location will signify the desired displacement.
We use the e3nn.networks.GatedConvNetwork class for our model.
from e3nn.networks import GatedConvNetwork
L_max = 5  # highest spherical-harmonic degree used throughout the network
# One copy (multiplicity 1) of every irrep L = 0..L_max.
Rs = [(1, L) for L in range(L_max + 1)]
# Pre-bind all hyperparameters; Network() then builds a fresh model.
Network = partial(
    GatedConvNetwork,
    Rs_in=Rs,
    Rs_hidden=Rs,
    Rs_out=Rs,
    lmax=L_max,
    max_radius=3.0,
    kernel=KernelMod,
)
In this task, our input is four points in the shape of a rectangle with simple scalars (1.0) at each point. The task is to learn to displace the points to form a (more symmetric) square.
# --- Task 1: learn to displace a rectangle into a square ---------------
model = Network()
params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-2)
loss_fn = torch.nn.MSELoss()

# One feature vector per point covering all irreps L = 0..L_max, with
# only the scalar (L=0) channel set to 1.
input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel

# Target: each point's displacement projected onto spherical harmonics.
displacements = square - rectangle
N, _ = displacements.shape
projections = torch.stack([
    SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal
    for i in range(N)
])

iterations = 100
for step in range(iterations):
    output = model(input, rectangle.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if step % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Plot spherical harmonic projections
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
def plot_output(start, finish, output, start_label, finish_label):
    """Build a 3-D plotly figure of a displacement prediction.

    Shows the ``start`` and ``finish`` point sets as markers plus the
    spherical-harmonic signal attached to each start point.

    Parameters
    ----------
    start, finish : (P, 3) tensors of point coordinates.
    output : (1, P, (lmax+1)^2) tensor of spherical-harmonic coefficients.
    start_label, finish_label : legend labels for the two point sets.

    Returns
    -------
    plotly Figure (caller is responsible for showing it).
    """
    fig = make_subplots(rows=1, cols=1, specs=[[{'is_3d': True}]])
    fig.add_trace(go.Scatter3d(x=start[:, 0], y=start[:, 1], z=start[:, 2], mode="markers", name=start_label))
    fig.add_trace(go.Scatter3d(x=finish[:, 0], y=finish[:, 1], z=finish[:, 2], mode="markers", name=finish_label))
    # Derive sizes from the data instead of the module-level N and L_max:
    # the channel count of a full spherical tensor is (lmax + 1)^2.
    n_points = output.shape[1]
    lmax = round(output.shape[-1] ** 0.5) - 1
    for i in range(n_points):
        r, f = SphericalTensor(output[0][i].detach(), 1, lmax).plot(center=start[i])
        trace = go.Surface(x=r[..., 0], y=r[..., 1], z=r[..., 2], surfacecolor=f.numpy())
        trace.showscale = False
        fig.add_trace(trace, 1, 1)
    return fig
# Visualize the trained model's prediction on the original rectangle.
output = model(input, rectangle.unsqueeze(0))
fig = plot_output(rectangle, square, output, "Rectangle", "Square")
fig.update_layout(scene_aspectmode='data')
fig.show()

# Equivariance check: rotate the whole problem by random Euler angles;
# the predicted signals should rotate along with the geometry.
angles = torch.rand(3) * torch.tensor([np.pi, 2 * np.pi, np.pi])
rot = o3.rot(*angles)
rot_rectangle = rectangle @ rot.t()
rot_square = square @ rot.t()
output = model(input, rot_rectangle.unsqueeze(0))
fig = plot_output(rot_rectangle, rot_square, output, "Rectangle", "Square")
fig.update_layout(scene_aspectmode='data')
fig.show()
In this task, our input is four points in the shape of a square with simple scalars (1.0) at each point. The task is to learn to displace the points to form a (less symmetric) rectangle. Can the network learn this task?
# --- Task 2: square -> rectangle (this one the network CANNOT solve, ---
# --- because the output symmetry would be lower than the input's) ------
model = Network()
params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-2)
loss_fn = torch.nn.MSELoss()

# Same scalar-only input features as Task 1.
input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel

displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([
    SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal
    for i in range(N)
])

iterations = 100
# Two identical rounds of training (200 steps total); the loss plateaus
# no matter how long we train.
for _round in range(2):
    for i in range(iterations):
        output = model(input, square.unsqueeze(0))
        loss = loss_fn(output, projections.unsqueeze(0))
        if i % 10 == 0:
            print(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

fig = plot_output(square, rectangle, output, "Square", "Rectangle")
fig.update_layout(scene_aspectmode='data')
fig.show()
To be able to do this task, you need to give the network more information -- information that breaks the symmetry to that of the desired output. The square has a point group of $D_{4h}$ (16 elements) while the rectangle has a point group of $D_{2h}$ (8 elements).
In this example, we are NOT using a network equivariant to parity -- that will be in another update / tutorial -- so we are actually only sensitive to the fact that the square has $C_4$ symmetry while the rectangle has $C_2$ symmetry.
In this task, our input is four points in the shape of a square with simple scalars (1.0) AND a contribution for the $x^2 - y^2$ feature at each point. The task is to learn to displace the points to form a (less symmetric) rectangle. Can the network learn this task?
# --- Task 3: square -> rectangle WITH an explicit symmetry-breaking ----
# --- x^2 - y^2 component added to the input features -------------------
model = Network()
params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-2)
loss_fn = torch.nn.MSELoss()

input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel
# Break the x/y symmetry: channel 8 is the last L=2 component, x^2 - y^2.
input[:, :, 8] = 0.1  # x^2 - y^2

displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([
    SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal
    for i in range(N)
])

iterations = 100
for step in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if step % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

fig = plot_output(square, rectangle, output, "Square", "Rectangle")
fig.update_layout(scene_aspectmode='data')
fig.show()
Notice how the shape below is an ellipsoid elongated in the y direction and squished in the x. This isn't the only perturbation we could've added, but it is the most symmetric.
# Visualize the symmetry-breaking perturbation by itself: a scalar (L=0)
# part plus a small x^2 - y^2 (L=2) component.
rows, cols = 1, 1
specs = [[{'is_3d': True} for _ in range(cols)] for _ in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

L_max = 5
Rs = [(1, L) for L in range(L_max + 1)]
sum_Ls = sum(2 * L + 1 for mult, L in Rs)

# Deterministic signal: zeros everywhere except the two channels below.
signal = torch.zeros(sum_Ls)
signal[0] = 1       # isotropic L=0 part
signal[8] = -0.1    # x^2 - y^2 component of the L=2 part

sphten = SphericalTensor(signal, 1, L_max)
r, f = sphten.plot(relu=False, n=60)
surface = go.Surface(x=r[..., 0], y=r[..., 1], z=r[..., 2], surfacecolor=f.numpy())
fig.add_trace(surface, row=1, col=1)
fig.show()
It's a bit of a complicated story, but at the surface level here it is: Character tables are handy tabulations of how certain spherical tensor datatypes transform under that group symmetry. The rows are irreducible representations (irrep for short) and the columns are similar elements of the group (called conjugacy classes). Character tables are most commonly seen for finite groups of $E(3)$ symmetry as they are used extensively in solid state physics, crystallography, chemistry, etc. Next to the part of the table with the "characters", there are often columns showing linear, quadratic, and cubic functions (meaning they are of order 1, 2, and 3) that transform in the same way as a given irrep.
So, a square has a point group symmetry of $D_{4h}$ while a rectangle has a point group symmetry of $D_{2h}$
If we look at column headers of character tables for $D_{4h}$ and $D_{2h}$...
... we can see that the irrep $B_{1g}$ of $D_{4h}$ that has -1's in the columns for all the symmetry operations that $D_{2h}$ DOESN'T have and if we look down that row to the column "quadratic functions" we see, voila $x^2 - y^2$. So, to break all those symmetries that $D_{4h}$ has that $D_{2h}$ DOESN'T have -- we add a non-zero contribution to the $x^2 - y^2$ component of our spherical harmonic tensors.
Again, in this example (because we are choosing to leave out parity), we are only sensitive to the fact that the square has $C_4$ symmetry while the rectangle has $C_2$ symmetry. However, you can check the character tables for the point groups $C_4$ and $C_2$ to see that the argument above still holds for the $x^2 - y^2$ order parameter.
In this task, our input is four points in the shape of a square with simple scalars (1.0) AND then we LEARN how to change the inputs to break symmetry such that we can fit a better model.
# --- Task 4: LEARN the symmetry-breaking input features ----------------
# Phase 1 trains only the model weights (as in Task 2); phase 2 then
# alternates one step on the model weights with one step on the input
# features themselves, using a second optimizer.
model = Network()
params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-2)
loss_fn = torch.nn.MSELoss()
# One feature vector per point covering all irreps L = 0..L_max; only the
# scalar (L=0) channel starts at 1.
input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1. # batch, point, channel
displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal for i in range(N)])
# From here on the input features are trainable, with their own optimizer
# and a much smaller learning rate.
input.requires_grad = True
input_optimizer = torch.optim.Adam([input], 1e-3)
input_loss_fn = torch.nn.MSELoss()
# NOTE(review): displacements/projections are recomputed identically here;
# only the extra unsqueeze(0) below is new.
displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal for i in range(N)])
projections = projections.unsqueeze(0)
# Phase 1: model weights only. (input.grad accumulates during this loop,
# but input_optimizer.zero_grad() below clears it before its first step.)
iterations = 201
for i in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections)
    if i % 30 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Phase 2: alternate a model-weight step with an input-feature step.
iterations = 101
eps = 1e-6  # NOTE(review): unused — candidate for removal
for i in range(iterations):
    # (a) one ordinary model-weight update
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections)
    if i % 10 == 0:
        print('model loss: ', loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # (b) one input-feature update on a fresh forward pass.
    output = model(input, square.unsqueeze(0))
    # Start from the regular fit loss...
    loss = input_loss_fn(output, projections)
    # ...then add terms that keep the change to the input small.
    # Feature layout: [0] = L=0 scalar, [1:4] = L=1, [4:9] = L=2, [9:] = L>=3.
    # Keep the scalar channel close to its original value of 1.
    loss += ((input[:, :, 0:1] - torch.ones_like(input[:, :, 0:1])).abs()).mean()
    # Suppress all L>=3 components.
    loss += ((input[:, :, 9:]).abs()).mean()
    # Suppress the L=1 components.
    loss += ((input[:, :, 1:4]).abs()).mean()
    # Prefer features on atoms to be the same (global parameter)
    loss += ((input[:, :, 4:9] - input[:, 0, 4:9])**2).mean()
    # and add a mild L1 penalty on the L=2 input components.
    loss += 1e-3 * ((input[:, :, 4:9]).abs()).mean()
    if i % 20 == 0:
        print('input loss: ', loss)
    # This backward also fills model-parameter grads, but they are zeroed
    # by optimizer.zero_grad() at the top of the next iteration before
    # optimizer.step() runs, so only the input is updated here.
    input_optimizer.zero_grad()
    loss.backward()
    input_optimizer.step()
# Inspect the learned input features, one irrep block at a time.
round_decimal = 3
values = input.detach().numpy().round(round_decimal)
print("L=0 ")
print(values[:, :, 0])
# (offset, width) of each remaining L-block in the flat layout; width = 2L + 1.
irrep_blocks = [(1, 3), (4, 5), (9, 7), (16, 9), (25, 11)]
for L, (offset, width) in enumerate(irrep_blocks, start=1):
    print(f"L={L}")
    print(values[:, :, offset: offset + width])
# Visualize the learned features as spherical-harmonic signals centered
# on the (unmoved) square.
fig = plot_output(square, square, input, '', '')
fig.update_layout(scene_aspectmode='data')
fig.show()